# Copyright (c) OpenMMLab. All rights reserved.
import logging
from typing import Dict, List, Optional

import cv2
import mmcv
import numpy as np
import torch
from mmengine.dist import master_only
from mmengine.logging import print_log
from mmengine.structures import PixelData
from mmengine.visualization import Visualizer

from mmseg.registry import VISUALIZERS
from mmseg.structures import SegDataSample
from mmseg.utils import get_classes, get_palette


@VISUALIZERS.register_module()
class SegVesselVisualizer(Visualizer):
    """Visualizer for two-branch (LAD / LCX) vessel segmentation.

    For each data sample it renders up to four panels side by side: the LAD
    and LCX ground-truth maps (``gt_seg_map_lad`` / ``gt_seg_map_lcx``) and
    the LAD and LCX predictions (``pred_seg_map_head1`` /
    ``pred_seg_map_head0``). When both ground truth and predictions are
    drawn, the binary Dice score of each branch is written on top of the
    concatenated panels. Depth maps (``gt_depth_map`` / ``pred_depth_map``)
    are used as a fallback when no vessel panel is drawn.

    Args:
        name (str): Name of the instance. Defaults to 'visualizer'.
        image (np.ndarray, optional): the origin image to draw. The format
            should be RGB. Defaults to None.
        vis_backends (list, optional): Visual backend config list.
            Defaults to None.
        save_dir (str, optional): Save file dir for all storage backends.
            If it is None, the backend storage will not save any data.
        classes (list, optional): Input classes for result rendering, as the
            prediction of segmentation model is a segment map with label
            indices, `classes` is a list which includes items responding to the
            label indices. If classes is not defined, visualizer will take
            `cityscapes` classes by default. Defaults to None.
        palette (list, optional): Input palette for result rendering, which is
            a list of color palette responding to the classes. Defaults to None.
        dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
            visualizer will use the meta information of the dataset i.e.
            classes and palette, but the `classes` and `palette` have higher
            priority. Defaults to None.
        alpha (float): The transparency of segmentation mask.
            Defaults to 0.8.

    Examples:
        >>> import numpy as np
        >>> import torch
        >>> from mmengine.structures import PixelData
        >>> from mmseg.structures import SegDataSample

        >>> visualizer = SegVesselVisualizer()
        >>> visualizer.dataset_meta = dict(
        ...     classes=('background', 'vessel'),
        ...     palette=[[0, 0, 0], [255, 0, 0]])
        >>> image = np.random.randint(0, 256,
        ...                           size=(10, 12, 3)).astype('uint8')
        >>> data_sample = SegDataSample()
        >>> data_sample.gt_seg_map_lad = PixelData(
        ...     data=torch.randint(0, 2, (1, 10, 12)))
        >>> data_sample.gt_seg_map_lcx = PixelData(
        ...     data=torch.randint(0, 2, (1, 10, 12)))
        >>> visualizer.add_datasample('vessel_example', image, data_sample)
    """  # noqa

    def __init__(self,
                 name: str = 'visualizer',
                 image: Optional[np.ndarray] = None,
                 vis_backends: Optional[Dict] = None,
                 save_dir: Optional[str] = None,
                 classes: Optional[List] = None,
                 palette: Optional[List] = None,
                 dataset_name: Optional[str] = None,
                 alpha: float = 0.8,
                 **kwargs):
        super().__init__(name, image, vis_backends, save_dir, **kwargs)
        self.alpha: float = alpha
        # Keyword arguments: the positional order of `set_dataset_meta` is
        # (classes, palette, dataset_name); passing them positionally in the
        # wrong order silently swapped user-supplied classes and palette.
        self.set_dataset_meta(
            classes=classes, palette=palette, dataset_name=dataset_name)

    def _get_center_loc(self, mask: np.ndarray) -> np.ndarray:
        """Pick an anchor point for a class label inside a binary mask.

        The row containing the most foreground pixels is chosen (ties go to
        the top-most row) and the middle foreground pixel of that row is
        returned. The input mask is not modified.

        Args:
            mask (np.ndarray): 2-D mask where foreground pixels equal 1.

        Returns:
            np.ndarray: ``[x, y]`` pixel coordinate; ``[0, 0]`` when the mask
            has no foreground pixel.
        """
        coords = np.argwhere(mask == 1)
        if coords.size == 0:
            # Nothing to anchor on: fall back to the image origin rather than
            # writing a fake foreground pixel into the caller's mask.
            return np.array([0, 0])
        # np.argwhere yields (y, x) pairs in row-major order, i.e. already
        # sorted by row and then by column.
        rows = coords[:, 0]
        row_values, row_counts = np.unique(rows, return_counts=True)
        best_row = row_values[row_counts.argmax()]
        row_coords = coords[rows == best_row]
        y, x = row_coords[len(row_coords) // 2]
        return np.array([x, y])

    @staticmethod
    def _pixel_data_to_2d(pixel_data: PixelData) -> np.ndarray:
        """Return the data of ``pixel_data`` as a 2-D numpy array.

        The data is moved to CPU, converted from a tensor if necessary, and
        when it is 3-D only the first channel is kept.
        """
        data = pixel_data.cpu().data
        if isinstance(data, torch.Tensor):
            data = data.numpy()
        if data.ndim == 3:
            data = data[0]
        return data

    def _draw_sem_seg(self,
                      image: np.ndarray,
                      sem_seg: PixelData,
                      classes: Optional[List],
                      palette: Optional[List],
                      mask_index: int,
                      with_labels: Optional[bool] = True) -> np.ndarray:
        """Draw semantic seg of GT or prediction.

        Args:
            image (np.ndarray): The image to draw.
            sem_seg (:obj:`PixelData`): Data structure for pixel-level
                annotations or predictions.
            classes (list, optional): Input classes for result rendering, as
                the prediction of segmentation model is a segment map with
                label indices, `classes` is a list which includes items
                responding to the label indices. Must not be None here.
            palette (list, optional): Input palette for result rendering, which
                is a list of color palette responding to the classes.
            mask_index (int): Channel of ``sem_seg.data`` to draw when the
                data is 3-D; ignored for 2-D data.
            with_labels(bool, optional): Add semantic labels in visualization
                result, Default to True.

        Returns:
            np.ndarray: the drawn image which channel is RGB.
        """
        num_classes = len(classes)
        seg = sem_seg.cpu().data
        if seg.ndim == 3:
            seg = seg[mask_index]
        if isinstance(seg, torch.Tensor):
            seg = seg.numpy()
        # From here on `seg` is a 2-D (H, W) label map.

        ids = np.unique(seg)[::-1]
        # Skip ids without a class/palette entry (e.g. ignore index 255 or
        # negative values, which would otherwise wrap around the palette).
        ids = ids[(ids >= 0) & (ids < num_classes)]
        labels = np.array(ids, dtype=np.int64)
        colors = [palette[label] for label in labels]

        mask = np.zeros_like(image, dtype=np.uint8)
        for label, color in zip(labels, colors):
            mask[seg == label, :] = color

        if with_labels:
            font = cv2.FONT_HERSHEY_SIMPLEX
            # (0,1] to change the size of the text relative to the image
            scale = 0.05
            font_scale = min(image.shape[0], image.shape[1]) / (25 / scale)
            font_color = (255, 255, 255)
            if image.shape[0] < 300 or image.shape[1] < 300:
                thickness = 1
                rectangle_thickness = 1
            else:
                thickness = 2
                rectangle_thickness = 2
            line_type = 2

            # One (H, W) binary mask per drawn label.
            label_masks = (seg[None] == labels[:, None, None]).astype(
                np.uint8)
            for class_id, class_color, label_mask in zip(
                    labels, colors, label_masks):
                loc = self._get_center_loc(label_mask)
                # OpenCV expects points as tuples of plain Python ints.
                x, y = int(loc[0]), int(loc[1])
                text = classes[class_id]
                (label_width, label_height), baseline = cv2.getTextSize(
                    text, font, font_scale, thickness)
                bottom_right = (x + label_width + baseline,
                                y + label_height + baseline)
                mask = cv2.rectangle(mask, (x, y), bottom_right, class_color,
                                     -1)
                mask = cv2.rectangle(mask, (x, y), bottom_right, (0, 0, 0),
                                     rectangle_thickness)
                mask = cv2.putText(mask, text, (x, y + label_height), font,
                                   font_scale, font_color, thickness,
                                   line_type)
        color_seg = (image * (1 - self.alpha) + mask * self.alpha).astype(
            np.uint8)
        self.set_image(color_seg)
        return color_seg

    def _draw_depth_map(self, image: np.ndarray,
                        depth_map: PixelData) -> np.ndarray:
        """Draws a depth map on a given image.

        This function takes an image and a depth map as input,
        renders the depth map, and concatenates it with the original image.
        Finally, it updates the internal image state of the visualizer with
        the concatenated result.

        Args:
            image (np.ndarray): The original image where the depth map will
                be drawn. The array should be in the format HxWx3 where H is
                the height, W is the width.

            depth_map (PixelData): Depth map to be drawn. The depth map
                should be in the form of a PixelData object. It will be
                converted to a torch tensor if it is a numpy array.

        Returns:
            np.ndarray: The concatenated image with the depth map drawn.

        Example:
            >>> depth_map_data = PixelData(data=torch.rand(1, 10, 10))
            >>> image = np.random.randint(0, 256,
            >>>                           size=(10, 10, 3)).astype('uint8')
            >>> visualizer = SegVesselVisualizer()
            >>> visualizer._draw_depth_map(image, depth_map_data)
        """
        depth_map = depth_map.cpu().data
        if isinstance(depth_map, np.ndarray):
            depth_map = torch.from_numpy(depth_map)
        if depth_map.ndim == 2:
            depth_map = depth_map[None]

        depth_map = self.draw_featmap(depth_map, resize_shape=image.shape[:2])
        # Image on top, rendered depth map below.
        out_image = np.concatenate((image, depth_map), axis=0)
        self.set_image(out_image)
        return out_image

    def _calculate_dice_score(self, pred_mask: np.ndarray,
                              gt_mask: np.ndarray) -> float:
        """Compute the binary Dice score (foreground label 1).

        Args:
            pred_mask (np.ndarray): Predicted mask (0=background,
                1=foreground).
            gt_mask (np.ndarray): Ground-truth mask (0=background,
                1=foreground).

        Returns:
            float: Dice score; 1.0 when both masks have no foreground, 0.0
            when only the ground truth is empty.
        """
        return self._calculate_class_specific_dice(pred_mask, gt_mask, 1)

    def _calculate_class_specific_dice(self, pred_mask: np.ndarray,
                                       gt_mask: np.ndarray,
                                       class_id: int = 1) -> float:
        """Compute the Dice score of a single class.

        Args:
            pred_mask (np.ndarray): Predicted label map.
            gt_mask (np.ndarray): Ground-truth label map.
            class_id (int): Label whose Dice score is computed.

        Returns:
            float: ``2 * |P ∩ G| / (|P| + |G|)`` for that class; 1.0 when the
            class is absent from both maps and 0.0 when it is absent only
            from the ground truth.
        """
        pred_class = pred_mask == class_id
        gt_class = gt_mask == class_id

        intersection = np.count_nonzero(pred_class & gt_class)
        pred_sum = np.count_nonzero(pred_class)
        gt_sum = np.count_nonzero(gt_class)

        # Ground truth has no pixel of this class.
        if gt_sum == 0:
            return 1.0 if pred_sum == 0 else 0.0

        # gt_sum > 0 here, so the denominator cannot be zero.
        return float(2.0 * intersection / (pred_sum + gt_sum))

    def _draw_text_on_image(self, image: np.ndarray, text: str,
                            position: tuple = (10, 30),
                            font_scale: float = 1.0,
                            color: tuple = (255, 255, 255),
                            thickness: int = 2) -> np.ndarray:
        """Draw ``text`` on a copy of ``image`` over a black box.

        Args:
            image (np.ndarray): Input image; it is not modified.
            text (str): Text to draw.
            position (tuple): Bottom-left corner ``(x, y)`` of the text.
            font_scale (float): OpenCV font scale.
            color (tuple): Text color, in the channel order of ``image``.
            thickness (int): Stroke thickness of the text.

        Returns:
            np.ndarray: A copy of ``image`` with the text drawn.
        """
        image_with_text = image.copy()
        font = cv2.FONT_HERSHEY_SIMPLEX

        # Measure the text so the background box can enclose it.
        (text_width, text_height), baseline = cv2.getTextSize(
            text, font, font_scale, thickness)

        # Opaque black background box with a 5 px margin (thickness -1 fills).
        x, y = position
        cv2.rectangle(image_with_text,
                      (x - 5, y - text_height - 5),
                      (x + text_width + 5, y + baseline + 5),
                      (0, 0, 0), -1)

        cv2.putText(image_with_text, text, position, font,
                    font_scale, color, thickness, cv2.LINE_AA)

        return image_with_text

    def set_dataset_meta(self,
                         classes: Optional[List] = None,
                         palette: Optional[List] = None,
                         dataset_name: Optional[str] = None) -> None:
        """Set meta information to visualizer.

        Args:
            classes (list, optional): Input classes for result rendering, as
                the prediction of segmentation model is a segment map with
                label indices, `classes` is a list which includes items
                responding to the label indices. If classes is not defined,
                visualizer will take `cityscapes` classes by default.
                Defaults to None.
            palette (list, optional): Input palette for result rendering, which
                is a list of color palette responding to the classes.
                Defaults to None.
            dataset_name (str, optional): `Dataset name or alias <https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/utils/class_names.py#L302-L317>`_
                visualizer will use the meta information of the dataset i.e.
                classes and palette, but the `classes` and `palette` have
                higher priority. Defaults to None.
        """  # noqa
        # Set default value. When calling
        # `SegVesselVisualizer().dataset_meta=xxx`,
        # it will override the default value.
        if dataset_name is None:
            dataset_name = 'cityscapes'
        classes = classes if classes else get_classes(dataset_name)
        palette = palette if palette else get_palette(dataset_name)
        assert len(classes) == len(
            palette), 'The length of classes should be equal to palette'
        self.dataset_meta: dict = {'classes': classes, 'palette': palette}

    def _compute_vessel_dice(self, data_sample: SegDataSample) -> tuple:
        """Compute the LAD and LCX Dice scores of one data sample.

        ``pred_seg_map_head1`` is compared with ``gt_seg_map_lad`` and
        ``pred_seg_map_head0`` with ``gt_seg_map_lcx``, the same pairing used
        when drawing the prediction panels. Per-branch diagnostics are logged
        at DEBUG level and the scores at INFO level.

        Returns:
            tuple: ``(dice_lad, dice_lcx)`` as floats.
        """
        pred_lad = self._pixel_data_to_2d(data_sample.pred_seg_map_head1)
        pred_lcx = self._pixel_data_to_2d(data_sample.pred_seg_map_head0)
        gt_lad = self._pixel_data_to_2d(data_sample.gt_seg_map_lad)
        gt_lcx = self._pixel_data_to_2d(data_sample.gt_seg_map_lcx)

        dice_lad = self._calculate_dice_score(pred_lad, gt_lad)
        dice_lcx = self._calculate_dice_score(pred_lcx, gt_lcx)

        for branch, pred, gt in (('LAD', pred_lad, gt_lad),
                                 ('LCX', pred_lcx, gt_lcx)):
            print_log(
                f'{branch} - Pred unique: {np.unique(pred)}, '
                f'GT unique: {np.unique(gt)}',
                logger='current',
                level=logging.DEBUG)
            print_log(
                f'{branch} - Pred foreground pixels: '
                f'{int(np.sum(pred == 1))}, '
                f'GT foreground pixels: {int(np.sum(gt == 1))}',
                logger='current',
                level=logging.DEBUG)
        print_log(
            f'LAD Dice: {dice_lad:.4f}, LCX Dice: {dice_lcx:.4f}',
            logger='current',
            level=logging.INFO)
        return dice_lad, dice_lcx

    @master_only
    def add_datasample(
            self,
            name: str,
            image: np.ndarray,
            data_sample: Optional[SegDataSample] = None,
            draw_gt: bool = True,
            draw_pred: bool = True,
            show: bool = False,
            wait_time: float = 0,
            out_file: Optional[str] = None,
            step: int = 0,
            with_labels: Optional[bool] = True) -> None:
        """Draw the LAD/LCX ground truth and predictions of one sample.

        The drawn vessel panels (GT LAD, GT LCX, pred LAD, pred LCX — only
        those that are available) are concatenated horizontally. When both
        GT and predictions are drawn, the per-branch Dice scores and their
        average are written on the result. If no vessel panel is drawn, a GT
        or predicted depth map is shown instead, or the raw image as a last
        resort.

        Args:
            name (str): Image identifier.
            image (np.ndarray): RGB image to draw on.
            data_sample (:obj:`SegDataSample`, optional): Sample holding
                ``gt_seg_map_lad``/``gt_seg_map_lcx`` and/or
                ``pred_seg_map_head0``/``pred_seg_map_head1`` (head1 is drawn
                as LAD, head0 as LCX). Defaults to None.
            draw_gt (bool): Whether to draw GT panels. Defaults to True.
            draw_pred (bool): Whether to draw prediction panels.
                Defaults to True.
            show (bool): Whether to display the result. Defaults to False.
            wait_time (float): Display wait time in seconds. Defaults to 0.
            out_file (str, optional): If given, the result is written to this
                path instead of being sent to the visualization backends.
            step (int): Global step passed to the backends. Defaults to 0.
            with_labels (bool, optional): Accepted for API compatibility;
                class labels are never drawn on the vessel panels.
        """
        classes = self.dataset_meta.get('classes', None)
        palette = self.dataset_meta.get('palette', None)

        # Initialise everything up front so later checks never hit an
        # undefined name.
        gt_img_data = None
        pred_img_data = None
        gt_img_data_lad = None
        gt_img_data_lcx = None
        pred_img_data_lad = None
        pred_img_data_lcx = None
        # None means "not computed" so that a genuine 0.0 score is still shown.
        dice_lad: Optional[float] = None
        dice_lcx: Optional[float] = None

        # Check for the fields that are actually read below.
        has_gt_vessels = (data_sample is not None
                          and 'gt_seg_map_lad' in data_sample
                          and 'gt_seg_map_lcx' in data_sample)
        has_pred_vessels = (data_sample is not None
                            and 'pred_seg_map_head0' in data_sample
                            and 'pred_seg_map_head1' in data_sample)

        if draw_gt and data_sample is not None:
            if has_gt_vessels:
                assert classes is not None
                gt_img_data_lad = self._draw_sem_seg(
                    image.copy(), data_sample.gt_seg_map_lad, classes,
                    palette, 0, False)
                gt_img_data_lcx = self._draw_sem_seg(
                    image.copy(), data_sample.gt_seg_map_lcx, classes,
                    palette, 0, False)

            if 'gt_depth_map' in data_sample:
                gt_img_data = gt_img_data if gt_img_data is not None else image
                gt_img_data = self._draw_depth_map(
                    gt_img_data, data_sample.gt_depth_map)

        if draw_pred and data_sample is not None:
            if has_pred_vessels:
                assert classes is not None
                pred_img_data_lad = self._draw_sem_seg(
                    image.copy(), data_sample.pred_seg_map_head1, classes,
                    palette, 0, False)
                pred_img_data_lcx = self._draw_sem_seg(
                    image.copy(), data_sample.pred_seg_map_head0, classes,
                    palette, 0, False)

                if draw_gt and has_gt_vessels:
                    dice_lad, dice_lcx = self._compute_vessel_dice(
                        data_sample)

            if 'pred_depth_map' in data_sample:
                pred_img_data = (pred_img_data
                                 if pred_img_data is not None else image)
                pred_img_data = self._draw_depth_map(
                    pred_img_data, data_sample.pred_depth_map)

        vessel_panels = [
            panel for panel in (gt_img_data_lad, gt_img_data_lcx,
                                pred_img_data_lad, pred_img_data_lcx)
            if panel is not None
        ]
        if vessel_panels:
            drawn_img = np.concatenate(vessel_panels, axis=1)

            if dice_lad is not None and dice_lcx is not None:
                height, width = drawn_img.shape[:2]
                # Scale the overlay text with the size of the result.
                font_scale = max(0.8, min(width, height) / 800)
                avg_dice = (dice_lad + dice_lcx) / 2
                dice_text = (f'LAD Dice: {dice_lad:.4f} | '
                             f'LCX Dice: {dice_lcx:.4f} | '
                             f'Avg: {avg_dice:.4f}')
                drawn_img = self._draw_text_on_image(
                    drawn_img,
                    dice_text,
                    position=(10, int(30 * font_scale)),
                    font_scale=font_scale,
                    color=(0, 255, 0),  # green text
                    thickness=max(1, int(2 * font_scale)))
        elif gt_img_data is not None:
            drawn_img = gt_img_data
        else:
            drawn_img = pred_img_data if pred_img_data is not None else image

        if show:
            self.show(drawn_img, win_name=name, wait_time=wait_time)

        if out_file is not None:
            mmcv.imwrite(mmcv.rgb2bgr(drawn_img), out_file)
        else:
            self.add_image(name, drawn_img, step)